今天我們來實際來跑簡單的Dataset,就是 DL 101 資料集 - MNIST。透過較為簡單的Dataset 來理解像GAN這種相對難的演算法,應該能較容易理解GAN!
首先,可以直接透過 tf 的 api 來 load MNIST 資料集:
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
這邊的部分,我們就先拿train 的資料集,畢竟是Unsupervised。
接下來可以做簡單的前處理,並轉成tf 的格式:
x_train = x_train.astype(np.float32) / 255.
train_data = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size*6).batch(batch_size).repeat()
train_data_iter = iter(train_data)
資料處理完後,我們可以先來定義 Generator
以及 Discriminiator
。這部分的話看可以看前一天的部分,前一天有完整的 Model define Generator 以及 Discriminitator。這邊的話,就簡單測試一下Model 產生的 樣態
g = Generator()
d = Discriminator()
noise = tf.random.normal([1, 100])
generated_image = g(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
以及
decision = d(generated_image)
print (decision)
可以從中簡單看出 Generator Gen出來的圖片,以及 當 Discriminator 判斷出來的 (未train)
接下來就可以簡單的做Model summary了
generator = Generator()
generator.build(input_shape=(batch_size, z_dim))
generator.summary()
discriminator = Discriminator()
discriminator.build(input_shape=(batch_size, 28, 28, 1))
discriminator.summary()
Model summary後,我們要來定義GAN的Loss,而 Generator 跟 Discriminator 的 loss是不太一樣的。
首先我們先從 Discriminator 開始:
Discriminator的部分,主要概念就是比較 Generator 所產生的假圖片跟真圖片來做比較!
def dis_loss(generator, discriminator, input_noise, real_image, is_trainig):
fake_image = generator(input_noise, is_trainig)
d_real_logits = discriminator(real_image, is_trainig)
d_fake_logits = discriminator(fake_image, is_trainig)
d_loss_real = loss_real(d_real_logits)
d_loss_fake = loss_fake(d_fake_logits)
loss = d_loss_real + d_loss_fake
return loss
Generator的部分,就是透過 noise 產生圖片後,想辦法去騙過 Discriminator 的結果 (fake_loss)。
def gen_loss(generator, discriminator, input_noise, is_trainig):
fake_image = generator(input_noise, is_trainig)
fake_loss = discriminator(fake_image, is_trainig)
loss = loss_real(fake_loss)
return loss
而其中 loss_real
以及 loss_fake
就是透過 Discriminator output 的loss去計算 tf.nn.sigmoid_cross_entropy
。 比較有差距的就是 labels 的部分,一個是for 真實圖片的loss (label 為1) 一個是for 假圖片的loss (label為0)
def loss_real(logits):
return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits)))
def loss_fake(logits):
return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.zeros_like(logits)))
定義完loss後,我們可以來寫train的period,記得 Generator 跟 Discriminator 的 tf.GradientTape()
要分開寫!然後weight 也是分開 update。
for epoch in range(epochs):
batch_x = next(train_data_iter)
batch_x = tf.reshape(batch_x, shape=inputs_shape)
batch_x = batch_x * 2.0 - 1.0
batch_z = tf.random.normal(shape=[batch_size, z_dim])
with tf.GradientTape() as tape:
d_loss = dis_loss(generator, discriminator, batch_z, batch_x, is_training)
grads = tape.gradient(d_loss, discriminator.trainable_variables)
d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
with tf.GradientTape() as tape:
g_loss = gen_loss(generator, discriminator, batch_z, is_training)
grads = tape.gradient(g_loss, generator.trainable_variables)
g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
if epoch % 100 == 0:
print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))
最後我們就來隨機 gen一張圖
是不是有點看起來像是手寫數字了 XD
GAN真的是一個蠻好玩的演算法,有興趣大家可以在網路上找資源!其時候多人已經最好很棒的transfer learning for GAN 。 Ex: 圖像風格轉變 等等。謝謝大家今天漫長的閱讀 ~ 明天是最後一天連假,祝福大家明天最後一天連假愉快!